library(tidyverse)
library(RColorBrewer)
# Pale grey-blue background colour (#ECF0F1), used by theme_cousteau below.
# NOTE(review): despite its name this is not R's named colour "skyblue".
# `maxColorValue` is spelled out instead of relying on partial matching of `max`.
skyblue <- rgb(236, 240, 241, maxColorValue = 255)

# Indicator of whether the symmetric interval x +/- q_z * se strictly
# contains the true value mu: 1 when it does, 0 when it does not.
# Vectorised (with recycling) over all arguments; NA in any input gives NA.
covered <- function(x, se, mu, q_z) {
  half_width <- q_z * se
  lower_ok <- (x - half_width) < mu
  upper_ok <- (x + half_width) > mu
  as.numeric(upper_ok & lower_ok)
}


## Load simulation output and reshape it to long format
# Every column is read as double except the character `dgp` identifier.
col_spec <- cols(.default = col_double(), dgp = col_character())
# Escape the dot and anchor the extension so that only files ending in
# ".csv" are read (the unanchored pattern also matched e.g. "*.csv.bak").
fn <- list.files("output/", pattern = "^binary.*\\.csv$")
if (length(fn) == 0) {
  stop("No simulation output matching 'output/binary*.csv' was found.",
       call. = FALSE)
}
out <- lapply(
  fn,
  function(x) read_csv(paste0("output/", x), col_types = col_spec)
)
out <- bind_rows(out)
# 0.25 is the true interaction effect the estimates are evaluated against.
out <- out %>%
  mutate(
    # t critical value for two-sided 95% intervals.
    # NOTE(review): df = n_units - n_covs; confirm it matches the models'
    # residual degrees of freedom.
    q_95 = qt(p = 0.975, df = n_units - n_covs),
    lasso_cov = covered(lasso_int, lasso_intse, 0.25, q_95),
    pds_cov = covered(pds_int, pds_intse, 0.25, q_95),
    oracle_cov = covered(oracle_int, oracle_intse, 0.25, q_95),
    fully_cov = covered(fully_int, fully_intse, 0.25, q_95),
    single_cov = covered(single_int, single_intse, 0.25, q_95),
    # q_95 is only a helper. It is created after the original columns, so it
    # would otherwise fall inside oracle_main:single_cov and be pivoted into
    # junk rows (method "q", qoi "95").
    q_95 = NULL
  ) %>%
  # Column names are "<method>_<qoi>", e.g. "pds_int", "pds_intse", "pds_cov".
  pivot_longer(cols = oracle_main:single_cov,
               names_to = c("method", "qoi"),
               names_sep = "_")


## Absolute bias, RMSE and Monte Carlo SD of the interaction estimates
# Summaries are taken per simulation cell and method; 0.25 is the true
# interaction effect. Groups are dropped so downstream code gets a plain tibble.
res <- out %>%
  filter(qoi == "int") %>%
  group_by(dgp, n_covs, R2_y, R2_d, method) %>%
  summarize(bias = abs(mean(value - 0.25)),
            rmse = sqrt(mean((value - 0.25)^2)),
            se = sd(value),
            .groups = "drop")
res

## Average estimated SE compared with the Monte Carlo SD of the estimates
res_se <- out %>%
  filter(qoi == "intse") %>%
  group_by(dgp, n_covs, R2_y, R2_d, method) %>%
  summarize(seavg = mean(value), .groups = "drop")
res_se
# Join on explicit simulation-cell keys rather than on whatever column names
# happen to be shared. se_bias > 0 means the reported SEs overstate the
# sampling variability of the estimates.
res_se <- res %>%
  inner_join(res_se, by = c("dgp", "n_covs", "R2_y", "R2_d", "method")) %>%
  mutate(se_bias = seavg - se)

## Empirical coverage of the 95% intervals (share of runs with cov == 1)
res_cov <- out %>%
  filter(qoi == "cov") %>%
  group_by(dgp, n_covs, R2_y, R2_d, method) %>%
  summarize(coverage = mean(value), .groups = "drop")
res_cov


## themes
# Global default: minimal theme with a box drawn around each panel.
theme_set(
  theme_minimal() +
    theme(panel.border = element_rect(fill = NA))
)
# Light theme whose plot and legend backgrounds use the `skyblue` colour.
theme_cousteau <- theme_light() +
  theme(
    plot.background = element_rect(color = skyblue, fill = skyblue),
    legend.background = element_rect(fill = skyblue)
  )

# Display label for each estimation method, keyed by the short method code
# used in the `method` column.
meth_labs <- c(fully = "Fully Moderated",
               lasso = "Post-Lasso",
               oracle = "Oracle",
               pds = "Post-Double Selection",
               single = "Single Interaction")

# Colours: 7 evenly spaced HCL hues (l = 65, c = 100), of which 5 are used.
# NOTE(review): this appears to mirror ggplot2's default discrete hue palette
# for 7 levels; confirm. The name `cols` shadows readr::cols(), but calls like
# cols(...) still find the function because R skips non-function objects
# when resolving a name in call position.
hues <- seq(15, 375, length.out = 7 + 1)
cols <- setNames(hcl(h = hues, l = 65, c = 100)[c(1, 2, 3, 4, 7)],
                 c("pds", "lasso", "oracle", "fully", "single"))

# Point shapes (pch codes) for each method.
shapes <- setNames(c(15, 18, 3, 17, 9),
                   c("pds", "lasso", "oracle", "fully", "single"))


# x-axis title (plotmath): partial R^2 of the X_i * V_i interaction on D_i.
xlab <- expression(paste(
  "Partial ", R^2, " for ", X[i] * V[i], " interaction on ", D[i]
))

# Facet strip labeller: rows read "<n_covs> Covariates", columns read
# R^2_y == <R2_y>, rendered as plotmath by label_bquote().
facet_labs <- label_bquote(
  rows = .(n_covs) ~ "Covariates",
  cols = {R^2}[y] == .(R2_y)
)

# Method subsets for the two plot variants: four "base" methods, and the same
# four plus the single-interaction model. Labels, colours and shapes are
# looked up by method code in the named lookup vectors.
base_meths <- c("lasso", "fully", "oracle", "pds")
all_meths <- c(base_meths, "single")

base_labs <- meth_labs[base_meths]
base_cols <- cols[base_meths]
base_shapes <- shapes[base_meths]

all_labs <- meth_labs[all_meths]
all_cols <- cols[all_meths]
all_shapes <- shapes[all_meths]


# Shared skeleton for the simulation-result plots.
#
# Keeps the "binary" DGP and the requested methods, puts R2_d on the x-axis,
# facets by n_covs (rows) x R2_y (columns), and maps colour and shape to the
# method. Geoms are added by the caller.
#
# data  - summary table with columns dgp, n_covs, R2_y, R2_d, method and `y`
# meths - method codes to keep; also selects labels, colours and shapes
# y     - name (string) of the column plotted on the y-axis
# title - plot title
# y_lab - y-axis title
sim_plot <- function(data, meths, y, title, y_lab) {
  data %>%
    filter(dgp == "binary", method %in% meths) %>%
    ggplot(aes(x = R2_d, y = .data[[y]], group = method, color = method,
               shape = method)) +
    labs(title = title,
         subtitle = "Binary Outcome with Logistic Lasso Models",
         x = xlab, y = y_lab,
         color = "Method", shape = "Method") +
    facet_grid(n_covs ~ R2_y, labeller = facet_labs) +
    scale_color_manual(labels = meth_labs[meths], values = cols[meths]) +
    scale_shape_manual(labels = meth_labs[meths], values = shapes[meths]) +
    # Keep the short y-axis title horizontal.
    theme(axis.title.y = element_text(angle = 0, vjust = 0.5, hjust = 0.5))
}

rmse_base <- sim_plot(res, base_meths, "rmse",
                      "Root Mean Square Error", "RMSE")
rmse_base_all <- sim_plot(res, all_meths, "rmse",
                          "Root Mean Square Error", "RMSE")
bias_base <- sim_plot(res, base_meths, "bias", "Absolute Bias", "|Bias|")
bias_base_all <- sim_plot(res, all_meths, "bias", "Absolute Bias", "|Bias|")
# Dashed reference line at the nominal 95% level; y-axis fixed to [0, 1].
cov_base <- sim_plot(res_cov, base_meths, "coverage",
                     "Coverage of 95% Confidence Intervals", "Coverage") +
  geom_hline(yintercept = 0.95, linetype = 2) +
  ylim(0, 1)



## Write figures
# Render one plot, with points and lines added, to an 8 x 5 in Cairo PDF.
# The print() call is explicit because a ggplot object is only drawn when
# printed, and top-level values are not auto-printed when this script is
# source()d. That would otherwise leave the PDFs blank. The device is closed
# via on.exit() even if drawing fails. Returns the file path invisibly.
save_sim_pdf <- function(plot, file) {
  dir.create(dirname(file), showWarnings = FALSE, recursive = TRUE)
  cairo_pdf(file, width = 8, height = 5, family = "Fira Sans")
  on.exit(dev.off(), add = TRUE)
  print(plot + geom_point(size = 3) + geom_line())
  invisible(file)
}

save_sim_pdf(rmse_base, "figures/binary-rmse-sim.pdf")
save_sim_pdf(rmse_base_all, "figures/binary-rmse-sim-all.pdf")
save_sim_pdf(bias_base, "figures/binary-bias-sim.pdf")
save_sim_pdf(bias_base_all, "figures/binary-bias-sim-all.pdf")
save_sim_pdf(cov_base, "figures/binary-cov-sim.pdf")
